#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import json
import filecmp
import shutil
import argparse


class Generator(object):
    """Accumulate the generated C++ NodeIntrospection implementation.

    Each ``Generate*`` method appends a fragment of C++ source text to
    ``implementationContent``; ``GenerateFiles`` writes the accumulated text
    to disk. The order in which the methods are called determines the layout
    of the generated file.
    """

    # Accumulated C++ text. Methods extend it with ``+=``, which rebinds the
    # name as an instance attribute (str is immutable), so this class-level
    # value only serves as the initial empty string.
    implementationContent = ""

    # Clades whose objects are handled by const reference in the generated
    # signatures; objects of every other clade are handled by const pointer.
    RefClades = {
        "DeclarationNameInfo",
        "NestedNameSpecifierLoc",
        "TemplateArgumentLoc",
        "TypeLoc",
    }

    def __init__(self, templateClasses):
        """Store the names of AST classes that carry template parameters.

        Args:
            templateClasses: class names that have a "templateParms" entry.
                They are skipped when emitting the per-clade dyn_cast/getAs
                dispatch; TypeLoc subclasses may still reach them through
                InheritanceMap (see GenerateBaseTypeLocVisit).
        """
        self.templateClasses = templateClasses

    def GeneratePrologue(self):
        """Append the file banner and shared C++ helpers.

        Opens ``namespace clang { namespace tooling {`` (closed again by
        ``GenerateEpilogue``) and emits the ``LocationAndString`` /
        ``RangeAndString`` aliases, ``hasIntrospectionSupport()`` returning
        true, and ``RecursionPopper``, whose destructor pops the last entry
        off the TypeLoc recursion-guard vector it was constructed with.
        """

        self.implementationContent += r"""
/*===- Generated file -------------------------------------------*- C++ -*-===*\
|*                                                                            *|
|* Introspection of available AST node SourceLocations                        *|
|*                                                                            *|
|* Automatically generated file, do not edit!                                 *|
|*                                                                            *|
\*===----------------------------------------------------------------------===*/

namespace clang {
namespace tooling {

using LocationAndString = SourceLocationMap::value_type;
using RangeAndString = SourceRangeMap::value_type;

bool NodeIntrospection::hasIntrospectionSupport() { return true; }

struct RecursionPopper
{
    RecursionPopper(std::vector<clang::TypeLoc> &TypeLocRecursionGuard)
    :  TLRG(TypeLocRecursionGuard)
    {

    }

    ~RecursionPopper()
    {
    TLRG.pop_back();
    }

private:
std::vector<clang::TypeLoc> &TLRG;
};
"""

    def GenerateBaseGetLocationsDeclaration(self, CladeName):
        """Append a forward declaration of ``GetLocationsImpl`` for a clade.

        The object parameter is ``clang::<CladeName> const &`` for clades in
        ``RefClades`` and ``clang::<CladeName> const *`` otherwise.
        """
        InstanceDecoration = "*"
        if CladeName in self.RefClades:
            InstanceDecoration = "&"

        self.implementationContent += """
void GetLocationsImpl(SharedLocationCall const& Prefix,
    clang::{0} const {1}Object, SourceLocationMap &Locs,
    SourceRangeMap &Rngs,
    std::vector<clang::TypeLoc> &TypeLocRecursionGuard);
""".format(
            CladeName, InstanceDecoration
        )

    def GenerateSrcLocMethod(self, ClassName, ClassData, CreateLocalRecursionGuard):
        """Append ``static void GetLocations<ClassName>(...)`` for one class.

        The generated function records the accessors listed in ClassData
        against ``Prefix``:

        * "sourceLocations" -> inserted into ``Locs``;
        * "sourceRanges"    -> inserted into ``Rngs``;
        * "typeLocs", "typeSourceInfos", "nestedNameLocs", "declNameInfos"
          -> recursed into via ``GetLocationsImpl``. The first three are
          guarded by a truthiness check on the accessor result;
          declNameInfos are recursed into unconditionally. For
          typeSourceInfos the recursion is on ``->getTypeLoc()``.

        If ClassData has "templateParms", the function is emitted as a
        template over those parameter names, and the object type is spelled
        ``ClassName<Parm, ...>`` while the function name keeps the bare
        class name.

        Args:
            ClassName: AST class name without the ``clang::`` qualifier.
            ClassData: accessor description for the class, read by key.
            CreateLocalRecursionGuard: if true, the generated function takes
                no recursion-guard parameter and declares a local
                ``TypeLocRecursionGuard`` vector when it recurses; if false,
                the vector is taken as a trailing parameter.
        """

        # Bare name, used for the generated function's name.
        NormalClassName = ClassName
        RecursionGuardParam = (
            ""
            if CreateLocalRecursionGuard
            else ", std::vector<clang::TypeLoc>& TypeLocRecursionGuard"
        )

        # For template classes, build both the template header and the
        # template-id spelling ("Name<A, B>") used as the parameter type.
        if "templateParms" in ClassData:
            TemplatePreamble = "template <typename "
            ClassName += "<"
            First = True
            for TA in ClassData["templateParms"]:
                if not First:
                    ClassName += ", "
                    TemplatePreamble += ", typename "

                First = False
                ClassName += TA
                TemplatePreamble += TA

            ClassName += ">"
            TemplatePreamble += ">\n"
            self.implementationContent += TemplatePreamble

        self.implementationContent += """
static void GetLocations{0}(SharedLocationCall const& Prefix,
    clang::{1} const &Object,
    SourceLocationMap &Locs, SourceRangeMap &Rngs {2})
{{
""".format(
            NormalClassName, ClassName, RecursionGuardParam
        )

        if "sourceLocations" in ClassData:
            for locName in ClassData["sourceLocations"]:
                self.implementationContent += """
  Locs.insert(LocationAndString(Object.{0}(),
    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}")));
""".format(
                    locName
                )

            self.implementationContent += "\n"

        if "sourceRanges" in ClassData:
            for rngName in ClassData["sourceRanges"]:
                self.implementationContent += """
  Rngs.insert(RangeAndString(Object.{0}(),
    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}")));
""".format(
                    rngName
                )

            self.implementationContent += "\n"

        # Only accessors that recurse into GetLocationsImpl need the guard.
        if (
            "typeLocs" in ClassData
            or "typeSourceInfos" in ClassData
            or "nestedNameLocs" in ClassData
            or "declNameInfos" in ClassData
        ):
            if CreateLocalRecursionGuard:
                self.implementationContent += (
                    "std::vector<clang::TypeLoc> TypeLocRecursionGuard;\n"
                )

            self.implementationContent += "\n"

            if "typeLocs" in ClassData:
                for typeLoc in ClassData["typeLocs"]:

                    self.implementationContent += """
              if (Object.{0}()) {{
                GetLocationsImpl(
                    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}"),
                    Object.{0}(), Locs, Rngs, TypeLocRecursionGuard);
                }}
              """.format(
                        typeLoc
                    )

            self.implementationContent += "\n"
            if "typeSourceInfos" in ClassData:
                for tsi in ClassData["typeSourceInfos"]:
                    self.implementationContent += """
              if (Object.{0}()) {{
                GetLocationsImpl(llvm::makeIntrusiveRefCnt<LocationCall>(
                    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}",
                        LocationCall::ReturnsPointer), "getTypeLoc"),
                    Object.{0}()->getTypeLoc(), Locs, Rngs, TypeLocRecursionGuard);
                    }}
              """.format(
                        tsi
                    )

                self.implementationContent += "\n"

            if "nestedNameLocs" in ClassData:
                for NN in ClassData["nestedNameLocs"]:
                    self.implementationContent += """
              if (Object.{0}())
                GetLocationsImpl(
                    llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}"),
                    Object.{0}(), Locs, Rngs, TypeLocRecursionGuard);
              """.format(
                        NN
                    )

            if "declNameInfos" in ClassData:
                for declName in ClassData["declNameInfos"]:

                    self.implementationContent += """
                      GetLocationsImpl(
                          llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "{0}"),
                          Object.{0}(), Locs, Rngs, TypeLocRecursionGuard);
                      """.format(
                        declName
                    )

        self.implementationContent += "}\n"

    def GenerateFiles(self, OutputFile):
        """Write the accumulated text to OutputFile, overwriting it.

        A relative OutputFile is joined onto the current working directory.
        """
        with open(os.path.join(os.getcwd(), OutputFile), "w") as f:
            f.write(self.implementationContent)

    def GenerateBaseGetLocationsFunction(
        self,
        ASTClassNames,
        ClassEntries,
        CladeName,
        InheritanceMap,
        CreateLocalRecursionGuard,
    ):
        """Append ``GetLocationsImpl`` and the public entry point for a clade.

        The generated ``GetLocationsImpl`` first calls
        ``GetLocations<CladeName>`` on the object itself, then dispatches to
        subclasses:

        * non-TypeLoc clades: one ``llvm::dyn_cast`` branch per class in
          ASTClassNames, excluding the clade itself and template classes;
        * TypeLoc: returns early for null TypeLocs and for TypeLocs already
          in ``TypeLocRecursionGuard``; otherwise pushes the object (popped
          again by ``RecursionPopper``), unwraps ``QualifiedTypeLoc`` by
          recursing on ``getNextTypeLoc()`` and returning, and emits
          ``getAs<>`` dispatch through ``GenerateBaseTypeLocVisit``.

        Finally emits ``NodeIntrospection::GetLocations(...)`` for the clade,
        which starts from an empty prefix and recursion guard and returns
        the collected ``NodeLocationAccessors``.

        Args:
            ASTClassNames: class names belonging to the clade.
            ClassEntries: mapping of class name to accessor data; forwarded
                to GenerateBaseTypeLocVisit, which only tests membership.
            CladeName: the clade's base class name.
            InheritanceMap: forwarded to GenerateBaseTypeLocVisit.
            CreateLocalRecursionGuard: when false, ``, TypeLocRecursionGuard``
                is passed to the generated ``GetLocations<Class>`` calls. It
                has to agree with the flag given to GenerateSrcLocMethod for
                those classes, or the generated calls will not match the
                generated signatures.
        """

        MethodReturnType = "NodeLocationAccessors"
        InstanceDecoration = "*"
        if CladeName in self.RefClades:
            InstanceDecoration = "&"

        Signature = "GetLocations(clang::{0} const {1}Object)".format(
            CladeName, InstanceDecoration
        )
        ImplSignature = """
    GetLocationsImpl(SharedLocationCall const& Prefix,
        clang::{0} const {1}Object, SourceLocationMap &Locs,
        SourceRangeMap &Rngs,
        std::vector<clang::TypeLoc> &TypeLocRecursionGuard)
    """.format(
            CladeName, InstanceDecoration
        )

        self.implementationContent += "void {0} {{ ".format(ImplSignature)

        # TypeLoc graphs can revisit the same TypeLoc; the generated code
        # stops on a repeat and pops the entry when the scope ends.
        if CladeName == "TypeLoc":
            self.implementationContent += "if (Object.isNull()) return;"

            self.implementationContent += """
            if (llvm::find(TypeLocRecursionGuard, Object) != TypeLocRecursionGuard.end())
              return;
            TypeLocRecursionGuard.push_back(Object);
            RecursionPopper RAII(TypeLocRecursionGuard);
                """

        RecursionGuardParam = ""
        if not CreateLocalRecursionGuard:
            RecursionGuardParam = ", TypeLocRecursionGuard"

        # Pointer clades are dereferenced to match the `const &Object`
        # parameter of the generated GetLocations<Class> functions.
        ArgPrefix = "*"
        if CladeName in self.RefClades:
            ArgPrefix = ""
        self.implementationContent += (
            "GetLocations{0}(Prefix, {1}Object, Locs, Rngs {2});".format(
                CladeName, ArgPrefix, RecursionGuardParam
            )
        )

        if CladeName == "TypeLoc":
            self.implementationContent += """
        if (auto QTL = Object.getAs<clang::QualifiedTypeLoc>()) {
            auto Dequalified = QTL.getNextTypeLoc();
            return GetLocationsImpl(llvm::makeIntrusiveRefCnt<LocationCall>(Prefix, "getNextTypeLoc"),
                                Dequalified,
                                Locs,
                                Rngs,
                                TypeLocRecursionGuard);
        }"""

        for ASTClassName in ASTClassNames:
            if ASTClassName in self.templateClasses:
                continue
            if ASTClassName == CladeName:
                continue
            if CladeName != "TypeLoc":
                self.implementationContent += """
if (auto Derived = llvm::dyn_cast<clang::{0}>(Object)) {{
  GetLocations{0}(Prefix, *Derived, Locs, Rngs {1});
}}
""".format(
                    ASTClassName, RecursionGuardParam
                )
                continue

            self.GenerateBaseTypeLocVisit(
                ASTClassName, ClassEntries, RecursionGuardParam, InheritanceMap
            )

        self.implementationContent += "}"

        self.implementationContent += """
{0} NodeIntrospection::{1} {{
  NodeLocationAccessors Result;
  SharedLocationCall Prefix;
  std::vector<clang::TypeLoc> TypeLocRecursionGuard;

  GetLocationsImpl(Prefix, Object, Result.LocationAccessors,
                   Result.RangeAccessors, TypeLocRecursionGuard);
""".format(
            MethodReturnType, Signature
        )

        self.implementationContent += "return Result; }"

    def GenerateBaseTypeLocVisit(
        self, ASTClassName, ClassEntries, RecursionGuardParam, InheritanceMap
    ):
        """Append ``getAs<>`` dispatch for one TypeLoc subclass.

        The call prefix is ``Prefix`` itself for ``TypeLoc`` and otherwise a
        ``LocationCall`` named ``getAs<clang::ASTClassName>`` marked
        ``IsCast``. Emits:

        * a call to ``GetLocations<ASTClassName>`` if the class is in
          ClassEntries;
        * for each template class whose name occurs in
          ``InheritanceMap[ASTClassName]``, a ``getAs`` on that inheritance
          entry followed by a call to the template's GetLocations function.
        """
        CallPrefix = "Prefix"
        if ASTClassName != "TypeLoc":
            CallPrefix = """llvm::makeIntrusiveRefCnt<LocationCall>(Prefix,
                    "getAs<clang::{0}>", LocationCall::IsCast)
                """.format(
                ASTClassName
            )

        if ASTClassName in ClassEntries:

            self.implementationContent += """
            if (auto ConcreteTL = Object.getAs<clang::{0}>())
              GetLocations{1}({2}, ConcreteTL, Locs, Rngs {3});
            """.format(
                ASTClassName, ASTClassName, CallPrefix, RecursionGuardParam
            )

        if ASTClassName in InheritanceMap:
            # NOTE(review): the entry is also formatted into
            # getAs<clang::{0}> below, which suggests it is a base-class
            # spelling (a string), making `in` a substring test - confirm
            # against the JSON producer.
            for baseTemplate in self.templateClasses:
                if baseTemplate in InheritanceMap[ASTClassName]:
                    self.implementationContent += """
    if (auto ConcreteTL = Object.getAs<clang::{0}>())
      GetLocations{1}({2}, ConcreteTL, Locs, Rngs {3});
    """.format(
                        InheritanceMap[ASTClassName],
                        baseTemplate,
                        CallPrefix,
                        RecursionGuardParam,
                    )

    def GenerateDynNodeVisitor(self, CladeNames):
        """Append ``NodeIntrospection::GetLocations(clang::DynTypedNode const &)``.

        For each clade except DeclarationNameInfo, the generated code tries
        ``Node.get<Clade>()`` and forwards a hit to that clade's
        GetLocations overload (dereferenced for RefClades, as a pointer
        otherwise). If no clade matches it returns ``{}``.
        """
        MethodReturnType = "NodeLocationAccessors"

        Signature = "GetLocations(clang::DynTypedNode const &Node)"

        self.implementationContent += (
            MethodReturnType + " NodeIntrospection::" + Signature + "{"
        )

        for CladeName in CladeNames:
            # NOTE(review): presumably skipped because DynTypedNode cannot
            # hold a DeclarationNameInfo - confirm.
            if CladeName == "DeclarationNameInfo":
                continue
            self.implementationContent += """
    if (const auto *N = Node.get<{0}>())
    """.format(
                CladeName
            )
            # Reference clades' overloads take `const &`, so deref N.
            ArgPrefix = ""
            if CladeName in self.RefClades:
                ArgPrefix = "*"
            self.implementationContent += """
      return GetLocations({0}const_cast<{1} *>(N));""".format(
                ArgPrefix, CladeName
            )

        self.implementationContent += "\nreturn {}; }"

    def GenerateEpilogue(self):
        """Close the two namespaces opened in GeneratePrologue."""

        self.implementationContent += """
  }
}
"""


def main():
    """Generate the NodeIntrospection C++ implementation from a JSON description.

    Command-line options:
      --json-input-path FILE        JSON with "classesInClade", "classEntries"
                                    and "classInheritance".
      --output-file FILEPATH        where to write the generated source.
      --use-empty-implementation N  non-zero forces the empty implementation.
      --empty-implementation FILE   the empty implementation to copy.

    The empty implementation is used when it is requested, when no JSON input
    path is given or the path does not exist, or when "classesInClade" is
    missing or empty. It is only copied when the output does not already
    match it, and the process then exits with status 0.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--json-input-path", help="Read API description from FILE", metavar="FILE"
    )
    parser.add_argument(
        "--output-file", help="Generate output in FILEPATH", metavar="FILEPATH"
    )
    parser.add_argument(
        "--use-empty-implementation",
        help="Generate empty implementation",
        action="store",
        type=int,
    )
    parser.add_argument(
        "--empty-implementation",
        help="Copy empty implementation from FILEPATH",
        action="store",
        metavar="FILEPATH",
    )

    options = parser.parse_args()

    # Every path below writes the output; report a usage error instead of a
    # TypeError from os/open on None.
    if not options.output_file:
        parser.error("--output-file is required")

    use_empty_implementation = options.use_empty_implementation

    # An omitted --json-input-path is None, and os.path.exists(None) raises
    # TypeError; treat it like a path that does not exist.
    if not use_empty_implementation and not (
        options.json_input_path and os.path.exists(options.json_input_path)
    ):
        use_empty_implementation = True

    if not use_empty_implementation:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(options.json_input_path, encoding="utf-8") as f:
            jsonData = json.load(f)

        if "classesInClade" not in jsonData or not jsonData["classesInClade"]:
            use_empty_implementation = True

    if use_empty_implementation:
        if not options.empty_implementation:
            parser.error(
                "--empty-implementation is required when no usable JSON "
                "input is available"
            )
        # Skip the copy when contents already match, presumably so an
        # unchanged output keeps its timestamp and does not force rebuilds.
        if not os.path.exists(options.output_file) or not filecmp.cmp(
            options.empty_implementation, options.output_file
        ):
            shutil.copyfile(options.empty_implementation, options.output_file)
        sys.exit(0)

    classesInClade = jsonData["classesInClade"]
    classEntries = jsonData["classEntries"]

    templateClasses = [
        ClassName
        for ClassName, ClassAccessors in classEntries.items()
        if "templateParms" in ClassAccessors
    ]

    g = Generator(templateClasses)

    g.GeneratePrologue()

    for CladeName in classesInClade:
        g.GenerateBaseGetLocationsDeclaration(CladeName)

    # Map each class to the first clade (in JSON order) listing it, built
    # once rather than rescanning every clade for every class.
    cladeOfClass = {}
    for CladeName, ClassNameData in classesInClade.items():
        for ClassName in ClassNameData:
            cladeOfClass.setdefault(ClassName, CladeName)

    for ClassName, ClassAccessors in classEntries.items():
        # Classes in no clade map to None and, like pointer clades, get a
        # local recursion guard.
        cladeName = cladeOfClass.get(ClassName)
        g.GenerateSrcLocMethod(
            ClassName, ClassAccessors, cladeName not in Generator.RefClades
        )

    for CladeName, ClassNameData in classesInClade.items():
        g.GenerateBaseGetLocationsFunction(
            ClassNameData,
            classEntries,
            CladeName,
            jsonData["classInheritance"],
            CladeName not in Generator.RefClades,
        )

    g.GenerateDynNodeVisitor(classesInClade.keys())

    g.GenerateEpilogue()

    g.GenerateFiles(options.output_file)


if __name__ == "__main__":
    main()
